{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1f07d4d",
   "metadata": {},
   "source": [
    "# Notebook Launcher examples\n",
    "\n",
    "A quick(ish) test of most of the main applications people use, taken from `fastbook`, and run with Accelerate across multiple GPUs through `notebook_launcher`"
   ]
  },
  {
   "cell_type": "raw",
   "id": "9d67d44e",
   "metadata": {},
   "source": [
    "---\n",
    "skip_exec: true\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0f9e33c-e06f-42de-8a71-63e520fb63c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision.all import *\n",
    "from fastai.text.all import *\n",
    "from fastai.tabular.all import *\n",
    "from fastai.collab import *\n",
    "\n",
    "from accelerate import notebook_launcher\n",
    "from fastai.distributed import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5499ff72",
   "metadata": {},
   "source": [
    ":::{.callout-important}\n",
    "\n",
    "Before running, ensure that Accelerate has been configured, either by running `accelerate config` on the command line or by calling `write_basic_config`.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de391d25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional one-time setup: uncomment both lines to write a basic Accelerate config\n",
    "# from Python instead of running `accelerate config` on the command line.\n",
    "# from accelerate.utils import write_basic_config\n",
    "# write_basic_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0873b97-ad92-432a-b9e3-c1194cef57a9",
   "metadata": {},
   "source": [
    "### Image Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a936abf3-1022-44ca-8816-3204e6fc0fc7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n",
      "Training Learner...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>error_rate</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.258557</td>\n",
       "      <td>0.024234</td>\n",
       "      <td>0.008119</td>\n",
       "      <td>00:13</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>error_rate</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.039532</td>\n",
       "      <td>0.019273</td>\n",
       "      <td>0.005413</td>\n",
       "      <td>00:15</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.PETS)/'images'\n",
    "\n",
    "def train():\n",
    "    \"Fine-tune a `resnet34` image classifier on PETS; launched in each process by `notebook_launcher`.\"\n",
    "    # A fixed `seed` (as in the original `fastbook` example) makes the random 80/20 split\n",
    "    # reproducible, so each launched process builds the same train/valid sets.\n",
    "    dls = ImageDataLoaders.from_name_func(\n",
    "        path, get_image_files(path), valid_pct=0.2, seed=42,\n",
    "        # Binary label: whether the file name starts with an uppercase letter\n",
    "        label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))\n",
    "    learn = vision_learner(dls, resnet34, metrics=error_rate).to_fp16()\n",
    "    # `in_notebook=True` because the processes are started by `notebook_launcher`\n",
    "    with learn.distrib_ctx(in_notebook=True, sync_bn=False):\n",
    "        learn.fine_tune(1)\n",
    "\n",
    "notebook_launcher(train, num_processes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecf6a12d-ef5a-4467-8c25-3b82d8ce9e17",
   "metadata": {},
   "source": [
    "### Image Segmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b84af35-eb81-4a84-955a-805b2ede7673",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py:1142: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  ret = func(*args, **kwargs)\n",
      "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py:1142: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  ret = func(*args, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Learner...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>4.339367</td>\n",
       "      <td>2.756470</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.405208</td>\n",
       "      <td>2.095044</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.113619</td>\n",
       "      <td>1.692979</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.931254</td>\n",
       "      <td>1.333691</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.753757</td>\n",
       "      <td>1.187579</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.612463</td>\n",
       "      <td>1.097649</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>1.493950</td>\n",
       "      <td>0.992424</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>1.393139</td>\n",
       "      <td>0.949843</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>1.312021</td>\n",
       "      <td>0.942510</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.CAMVID_TINY)\n",
    "\n",
    "def train():\n",
    "    \"Build the CAMVID_TINY segmentation `DataLoaders` and fine-tune a `resnet34` U-Net for 8 epochs.\"\n",
    "    def mask_for(img):\n",
    "        # The mask of `<stem><suffix>` lives at `labels/<stem>_P<suffix>`\n",
    "        return path/'labels'/f'{img.stem}_P{img.suffix}'\n",
    "\n",
    "    image_files = get_image_files(path/\"images\")\n",
    "    class_codes = np.loadtxt(path/'codes.txt', dtype=str)\n",
    "    dls = SegmentationDataLoaders.from_label_func(\n",
    "        path, bs=8, fnames=image_files, label_func=mask_for, codes=class_codes)\n",
    "    learn = unet_learner(dls, resnet34)\n",
    "    with learn.distrib_ctx(in_notebook=True, sync_bn=False):\n",
    "        learn.fine_tune(8)\n",
    "\n",
    "notebook_launcher(train, num_processes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c183e8f-3368-4637-9f3b-53586f21d613",
   "metadata": {},
   "source": [
    "### Text Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "789bf920-f035-4a2f-b91c-4ccca726c468",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n",
      "Training Learner...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.683830</td>\n",
       "      <td>0.640674</td>\n",
       "      <td>0.710000</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.505055</td>\n",
       "      <td>0.618315</td>\n",
       "      <td>0.650000</td>\n",
       "      <td>00:10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.433232</td>\n",
       "      <td>0.522627</td>\n",
       "      <td>0.785000</td>\n",
       "      <td>00:11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.391711</td>\n",
       "      <td>0.460229</td>\n",
       "      <td>0.810000</td>\n",
       "      <td>00:11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.347983</td>\n",
       "      <td>0.450882</td>\n",
       "      <td>0.805000</td>\n",
       "      <td>00:11</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.IMDB_SAMPLE)\n",
    "df = pd.read_csv(path/'texts.csv')\n",
    "\n",
    "def train():\n",
    "    \"Fine-tune an `AWD_LSTM` text classifier on the IMDB sample held in `df`.\"\n",
    "    text_block = TextBlock.from_df('text', seq_len=72)\n",
    "    imdb_clas = DataBlock(\n",
    "        blocks=(text_block, CategoryBlock),\n",
    "        get_x=ColReader('text'),\n",
    "        get_y=ColReader('label'),\n",
    "        splitter=ColSplitter())\n",
    "    dls = imdb_clas.dataloaders(df, bs=64)\n",
    "\n",
    "    def make_learner():\n",
    "        return text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)\n",
    "\n",
    "    # Learner creation is wrapped in `rank0_first` (from `fastai.distributed`)\n",
    "    learn = rank0_first(make_learner)\n",
    "    with learn.distrib_ctx(in_notebook=True):\n",
    "        learn.fine_tune(4, 1e-2)\n",
    "\n",
    "notebook_launcher(train, num_processes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc0a3984-f078-4b38-8d4a-7133ab7e8e65",
   "metadata": {},
   "source": [
    "### Tabular"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32fe868f-b555-4b4e-b367-14e3ce9c6fe7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n",
      "Training Learner...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.436493</td>\n",
       "      <td>0.383866</td>\n",
       "      <td>0.832463</td>\n",
       "      <td>00:03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.359663</td>\n",
       "      <td>0.352825</td>\n",
       "      <td>0.837224</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.349231</td>\n",
       "      <td>0.350312</td>\n",
       "      <td>0.839988</td>\n",
       "      <td>00:02</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "\n",
    "def train():\n",
    "    \"Train a tabular model predicting `salary` from `adult.csv` for 3 one-cycle epochs.\"\n",
    "    # `from_csv` reads the file itself, so no separate `pd.read_csv` call is needed.\n",
    "    dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names=\"salary\",\n",
    "            cat_names = ['workclass', 'education', 'marital-status', 'occupation',\n",
    "                         'relationship', 'race'],\n",
    "            cont_names = ['age', 'fnlwgt', 'education-num'],\n",
    "            procs = [Categorify, FillMissing, Normalize])\n",
    "\n",
    "    learn = tabular_learner(dls, metrics=accuracy)\n",
    "    with learn.distrib_ctx(in_notebook=True):\n",
    "        learn.fit_one_cycle(3)\n",
    "\n",
    "notebook_launcher(train, num_processes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "984f091d-eb13-4f4a-9834-b737cf1b0f76",
   "metadata": {},
   "source": [
    "### Collaborative Filtering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6701d4b-e3df-46ea-bc6d-c66ccf1f754e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n",
      "Training Learner...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.553747</td>\n",
       "      <td>1.430443</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.484851</td>\n",
       "      <td>1.394805</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.424410</td>\n",
       "      <td>1.255329</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.277911</td>\n",
       "      <td>1.028214</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.099660</td>\n",
       "      <td>0.882485</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.969005</td>\n",
       "      <td>0.835191</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.894699</td>\n",
       "      <td>0.828167</td>\n",
       "      <td>00:00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.ML_SAMPLE)\n",
    "df = pd.read_csv(path/'ratings.csv')\n",
    "\n",
    "def train():\n",
    "    \"Fit a collaborative-filtering learner on the ratings in `df` for 6 fine-tuning epochs.\"\n",
    "    rating_bounds = (0.5,5.5)  # passed to the learner as `y_range`\n",
    "    ratings_dls = CollabDataLoaders.from_df(df)\n",
    "    learner = collab_learner(ratings_dls, y_range=rating_bounds)\n",
    "    with learner.distrib_ctx(in_notebook=True):\n",
    "        learner.fine_tune(6)\n",
    "\n",
    "notebook_launcher(train, num_processes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84513855-0bce-43dc-b353-865af31588ed",
   "metadata": {},
   "source": [
    "### Keypoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a262a140-344a-4bc5-b6fc-87fcc348b480",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Launching training on 2 GPUs.\n",
      "Training Learner...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.247702</td>\n",
       "      <td>0.066427</td>\n",
       "      <td>00:47</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* Turns off some styling */\n",
       "    progress {\n",
       "        /* gets rid of default border in Firefox and Opera. */\n",
       "        border: none;\n",
       "        /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "        background-size: auto;\n",
       "    }\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #F44336;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.052143</td>\n",
       "      <td>0.007451</td>\n",
       "      <td>00:55</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "path = untar_data(URLs.BIWI_HEAD_POSE)\n",
    "# Calibration values read from `01/rgb.cal`; defined first because `get_ctr` reads them\n",
    "cal = np.genfromtxt(path/'01'/'rgb.cal', skip_footer=6)\n",
    "\n",
    "def img2pose(x):\n",
    "    \"Path of the pose file for image `x`: the last 7 characters of the path are replaced by `pose.txt`.\"\n",
    "    # NOTE(review): the 7 dropped characters are presumably the `rgb.jpg` suffix -- confirm against the dataset layout\n",
    "    return Path(f'{str(x)[:-7]}pose.txt')\n",
    "\n",
    "def get_ctr(f):\n",
    "    \"Two-element point target for image `f`, computed from its pose file and `cal`.\"\n",
    "    ctr = np.genfromtxt(img2pose(f), skip_header=3)\n",
    "    c1 = ctr[0] * cal[0][0]/ctr[2] + cal[0][2]\n",
    "    c2 = ctr[1] * cal[1][1]/ctr[2] + cal[1][2]\n",
    "    return tensor([c1,c2])\n",
    "\n",
    "\n",
    "def train():\n",
    "    \"Fine-tune a `resnet18` point regressor on BIWI; items whose parent folder is `13` are selected by the splitter.\"\n",
    "    # `get_items=get_image_files` scans `path` itself, so no file list is pre-computed at module level.\n",
    "    biwi = DataBlock(\n",
    "            blocks=(ImageBlock, PointBlock),\n",
    "            get_items=get_image_files,\n",
    "            get_y=get_ctr,\n",
    "            splitter=FuncSplitter(lambda o: o.parent.name=='13'),\n",
    "            batch_tfms=[*aug_transforms(size=(240,320)),\n",
    "                        Normalize.from_stats(*imagenet_stats)])\n",
    "    dls = biwi.dataloaders(path)\n",
    "    learn = vision_learner(dls, resnet18, y_range=(-1,1))\n",
    "    with learn.distrib_ctx(in_notebook=True, sync_bn=False):\n",
    "        learn.fine_tune(1)\n",
    "\n",
    "notebook_launcher(train, num_processes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3b42536",
   "metadata": {},
   "source": [
    "## fin -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "563206bd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
